import logging
from concurrent.futures import Future
from typing import List, Optional

from nuplan.planning.metrics.metric_engine import MetricsEngine
from nuplan.planning.scenario_builder.abstract_scenario import AbstractScenario
from nuplan.planning.simulation.callback.abstract_callback import AbstractCallback
from nuplan.planning.simulation.history.simulation_history import SimulationHistory, SimulationHistorySample
from nuplan.planning.simulation.planner.abstract_planner import AbstractPlanner
from nuplan.planning.simulation.simulation_setup import SimulationSetup
from nuplan.planning.simulation.trajectory.abstract_trajectory import AbstractTrajectory
from nuplan.planning.utils.multithreading.worker_pool import Task, WorkerPool

# Module-level logger named after this module; used for debug progress messages in run_metric_engine.
logger = logging.getLogger(__name__)


def run_metric_engine(
    metric_engine: MetricsEngine, scenario: AbstractScenario, planner_name: str, history: SimulationHistory
) -> None:
    """
    Compute the metrics of a finished simulation and persist them through the metric engine.
    :param metric_engine: Engine that both computes the metrics and writes the results out.
    :param scenario: Scenario the simulation history belongs to.
    :param planner_name: Name of the planner whose run is being evaluated.
    :param history: Recorded simulation history the metrics are computed from.
    """
    # Step 1: compute; whatever compute() returns is handed unchanged to write_to_files().
    logger.debug("Starting metrics computation...")
    computed_metrics = metric_engine.compute(history, scenario=scenario, planner_name=planner_name)
    logger.debug("Finished metrics computation!")

    # Step 2: persist.
    logger.debug("Saving metric statistics!")
    metric_engine.write_to_files(computed_metrics)
    logger.debug("Saved metrics!")


class MetricCallback(AbstractCallback):
    """
    Simulation callback that runs the metric engine once a simulation has finished.

    Every hook other than ``on_simulation_end`` is a no-op.
    """

    def __init__(self, metric_engine: MetricsEngine, worker_pool: Optional[WorkerPool] = None):
        """
        Build a metric callback.
        :param metric_engine: Metric engine used to compute and write the metrics.
        :param worker_pool: Optional worker pool. When provided, the metric computation is submitted to it
            at the end of the simulation instead of being executed inline.
        """
        self._pool = worker_pool
        self._metric_engine = metric_engine
        # Futures of tasks submitted to ``self._pool``; stays empty when no pool is used.
        self._futures: List[Future[None]] = []

    @property
    def metric_engine(self) -> MetricsEngine:
        """
        Metric engine this callback hands the simulation history to.
        :return: The metric engine given at construction.
        """
        return self._metric_engine

    @property
    def futures(self) -> List[Future[None]]:
        """
        Futures of the asynchronously submitted metric computations, e.g. for the main process to block on.
        :return: The internal list of futures (not a copy); empty when no worker pool is configured.
        """
        return self._futures

    def on_initialization_start(self, setup: SimulationSetup, planner: AbstractPlanner) -> None:
        """Inherited, see superclass. No-op for this callback."""

    def on_initialization_end(self, setup: SimulationSetup, planner: AbstractPlanner) -> None:
        """Inherited, see superclass. No-op for this callback."""

    def on_step_start(self, setup: SimulationSetup, planner: AbstractPlanner) -> None:
        """Inherited, see superclass. No-op for this callback."""

    def on_step_end(self, setup: SimulationSetup, planner: AbstractPlanner, sample: SimulationHistorySample) -> None:
        """Inherited, see superclass. No-op for this callback."""

    def on_planner_start(self, setup: SimulationSetup, planner: AbstractPlanner) -> None:
        """Inherited, see superclass. No-op for this callback."""

    def on_planner_end(self, setup: SimulationSetup, planner: AbstractPlanner, trajectory: AbstractTrajectory) -> None:
        """Inherited, see superclass. No-op for this callback."""

    def on_simulation_start(self, setup: SimulationSetup) -> None:
        """Inherited, see superclass. No-op for this callback."""

    def on_simulation_end(self, setup: SimulationSetup, planner: AbstractPlanner, history: SimulationHistory) -> None:
        """
        Inherited, see superclass.

        Runs ``run_metric_engine`` on the finished simulation. Without a worker pool it runs inline in the
        calling code; with one, it is submitted as a task requesting 1 CPU and 0 GPUs, and its future
        replaces whatever ``futures`` held before.
        """
        if self._pool is None:
            run_metric_engine(
                metric_engine=self._metric_engine,
                history=history,
                scenario=setup.scenario,
                planner_name=planner.name(),
            )
            return

        # Only the most recent submission is tracked; clear before submitting.
        self._futures = []
        metric_task = Task(run_metric_engine, num_cpus=1, num_gpus=0)
        submitted = self._pool.submit(
            metric_task,
            metric_engine=self._metric_engine,
            history=history,
            scenario=setup.scenario,
            planner_name=planner.name(),
        )
        self._futures.append(submitted)
